﻿/*
 * Copyright 2016 The Cartographer Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "slam/mapping/internal/2d/scan_matching/fast_correlative_scan_matcher_2d.h"

#include <algorithm>
#include <cmath>
#include <deque>
#include <functional>
#include <iterator>
#include <limits>

#include "Eigen/Geometry"
#include "absl/memory/memory.h"
#include "glog/logging.h"
#include "slam/common/math.h"
#include "slam/mapping/2d/grid_2d.h"
#include "slam/mapping/internal/constraints/constraint_builder_2d.h"
#include "slam/sensor/point_cloud.h"
#include "slam/transform/transform.h"

using namespace slam::mapping;
namespace slam {
namespace mapping {
namespace scan_matching {
namespace {

// Maintains a window of float values (added at one end, removed from the
// other) and reports the current maximum of the window. Each value enters and
// leaves the internal deque at most once, so every operation is amortized O(1).
class SlidingWindowMaximum {
 public:
  // Appends `value` to the window. Stored candidates that `value` strictly
  // exceeds are dropped first: they can never be the maximum again while
  // `value` is still part of the window.
  void AddValue(const float value) {
    std::deque<float>& maxima = non_ascending_maxima_;
    while (!maxima.empty()) {
      if (!(value > maxima.back())) {
        break;
      }
      maxima.pop_back();
    }
    maxima.push_back(value);
  }

  // Removes `value` from the window. Callers in this file remove values in the
  // order they were added; the deque only changes when the expiring value is
  // the one currently reported as the maximum.
  void RemoveValue(const float value) {
    // Only DCHECKs: this runs for every value of every precomputation grid.
    DCHECK(!non_ascending_maxima_.empty());
    DCHECK_LE(value, non_ascending_maxima_.front());
    if (non_ascending_maxima_.front() == value) {
      non_ascending_maxima_.pop_front();
    }
  }

  // Returns the largest value currently in the window (window must be
  // non-empty; only DCHECKed for performance).
  float GetMaximum() const {
    DCHECK_GT(non_ascending_maxima_.size(), 0);
    return non_ascending_maxima_.front();
  }

  // Aborts (CHECK, active in release builds too) unless every added value has
  // been removed again.
  void CheckIsEmpty() const { CHECK_EQ(non_ascending_maxima_.size(), 0); }

 private:
  // Non-ascending from front to back. The front is the maximum of the whole
  // window; each later entry is the maximum of the values added after the
  // entry in front of it.
  std::deque<float> non_ascending_maxima_;
};

}  // namespace

// Builds the fast correlative scan matcher options from a Lua dictionary.
//
// `linear_search_window`, `angular_search_window` and `branch_and_bound_depth`
// are required. The dynamic-search parameters are all-or-nothing: they are
// read only when every key in kDynamicParameterKeys is present. Otherwise the
// built-in defaults below are used, and a warning is logged if the
// configuration sets only some of them.
proto::FastCorrelativeScanMatcherOptions2D
CreateFastCorrelativeScanMatcherOptions2D(
    common::LuaParameterDictionary* const parameter_dictionary) {
  proto::FastCorrelativeScanMatcherOptions2D options;
  options.set_linear_search_window(
      parameter_dictionary->GetDouble("linear_search_window"));
  options.set_angular_search_window(
      parameter_dictionary->GetDouble("angular_search_window"));
  options.set_branch_and_bound_depth(
      parameter_dictionary->GetInt("branch_and_bound_depth"));

  static const char* const kDynamicParameterKeys[] = {
      "use_dynamic_parameters_after_n_nodes",
      "min_score",
      "max_score",
      "score_steps",
      "min_linear_search_window",
      "max_linear_search_window",
      "linear_search_window_steps",
      "min_angular_search_window",
      "max_angular_search_window",
      "angular_search_window_steps",
  };
  const auto num_dynamic_keys_present = std::count_if(
      std::begin(kDynamicParameterKeys), std::end(kDynamicParameterKeys),
      [parameter_dictionary](const char* const key) {
        return parameter_dictionary->HasKey(key);
      });
  const auto num_dynamic_keys =
      std::end(kDynamicParameterKeys) - std::begin(kDynamicParameterKeys);

  if (num_dynamic_keys_present == num_dynamic_keys) {
    options.set_use_dynamic_parameters_after_n_nodes(
        parameter_dictionary->GetInt("use_dynamic_parameters_after_n_nodes"));
    options.set_min_score(parameter_dictionary->GetDouble("min_score"));
    options.set_max_score(parameter_dictionary->GetDouble("max_score"));
    options.set_score_steps(parameter_dictionary->GetInt("score_steps"));
    options.set_min_linear_search_window(
        parameter_dictionary->GetDouble("min_linear_search_window"));
    options.set_max_linear_search_window(
        parameter_dictionary->GetDouble("max_linear_search_window"));
    options.set_linear_search_window_steps(
        parameter_dictionary->GetInt("linear_search_window_steps"));
    options.set_min_angular_search_window(
        parameter_dictionary->GetDouble("min_angular_search_window"));
    options.set_max_angular_search_window(
        parameter_dictionary->GetDouble("max_angular_search_window"));
    options.set_angular_search_window_steps(
        parameter_dictionary->GetInt("angular_search_window_steps"));
  } else {
    if (num_dynamic_keys_present > 0) {
      // A partial set is almost certainly a configuration mistake; report it
      // instead of silently falling back to the defaults.
      LOG(WARNING) << "Only " << num_dynamic_keys_present << " of "
                   << num_dynamic_keys
                   << " dynamic fast correlative scan matcher parameters are "
                      "set; using the built-in dynamic defaults instead.";
    }
    // Built-in defaults: dynamic search disabled (0 nodes), a single score
    // step and a single window step.
    options.set_use_dynamic_parameters_after_n_nodes(0);
    options.set_min_score(0.55);
    options.set_max_score(0.55);
    options.set_score_steps(1);
    options.set_min_linear_search_window(2);
    options.set_max_linear_search_window(5);
    options.set_linear_search_window_steps(1);
    options.set_min_angular_search_window(0.52);
    options.set_max_angular_search_window(0.52);
    options.set_angular_search_window_steps(1);
  }
  return options;
}

// Builds the precomputed grid for one branch-and-bound level.
//
// The result covers a widened area (wide_limits_: `width - 1` extra cells per
// axis). The widened cell at index i corresponds to original cell
// i + offset_ (offset_ = -width + 1). Each cell stores the maximum of
// 1 - |GetCorrespondenceCost| over the `width` x `width` window of original
// cells starting at that original cell, quantized by ComputeCellValue. Window
// cells that fall outside `limits` are simply not added.
//
// grid: source grid. Its cells are read through GetCorrespondenceCost, and its
//   min/max correspondence cost fix the quantization range (min_score_,
//   max_score_).
// limits: cell limits of the area to precompute; CHECKed to be at least 1x1.
// width: window side length in cells; CHECKed to be >= 1.
// reusable_intermediate_grid: scratch buffer, dereferenced without a null
//   check. It is resized and overwritten here so a single allocation can be
//   shared between levels.
PrecomputationGrid2D::PrecomputationGrid2D(
    const Grid2D& grid, const CellLimits& limits, const int width,
    std::vector<float>* reusable_intermediate_grid)
    : offset_(-width + 1, -width + 1),
      wide_limits_(limits.num_x_cells + width - 1,
                   limits.num_y_cells + width - 1),
      min_score_(1.f - grid.GetMaxCorrespondenceCost()),
      max_score_(1.f - grid.GetMinCorrespondenceCost()),
      cells_(wide_limits_.num_x_cells * wide_limits_.num_y_cells) {
  CHECK_GE(width, 1);
  CHECK_GE(limits.num_x_cells, 1);
  CHECK_GE(limits.num_y_cells, 1);
  const int stride = wide_limits_.num_x_cells;
  // First we compute the maximum probability for each (x0, y) achieved in the
  // span defined by x0 <= x < x0 + width.
  // Row-wise pass: intermediate is wide in x but not in y. The window starting
  // at original x0 is stored at widened column x0 + width - 1.
  std::vector<float>& intermediate = *reusable_intermediate_grid;
  intermediate.resize(wide_limits_.num_x_cells * limits.num_y_cells);
  for (int y = 0; y != limits.num_y_cells; ++y) {
    SlidingWindowMaximum current_values;
    current_values.AddValue(
        1.f - std::abs(grid.GetCorrespondenceCost(Eigen::Array2i(0, y))));
    // Ramp-up: windows starting left of the grid (x0 < 0) only contain cells
    // 0 .. x0 + width - 1, so cells are added but none are removed yet.
    for (int x = -width + 1; x != 0; ++x) {
      intermediate[x + width - 1 + y * stride] = current_values.GetMaximum();
      if (x + width < limits.num_x_cells) {
        current_values.AddValue(1.f - std::abs(grid.GetCorrespondenceCost(
                                          Eigen::Array2i(x + width, y))));
      }
    }
    // Steady state: the window fits entirely inside the row; slide it by one
    // cell (drop x, add x + width). Empty when width >= num_x_cells.
    for (int x = 0; x < limits.num_x_cells - width; ++x) {
      intermediate[x + width - 1 + y * stride] = current_values.GetMaximum();
      current_values.RemoveValue(
          1.f - std::abs(grid.GetCorrespondenceCost(Eigen::Array2i(x, y))));
      current_values.AddValue(1.f - std::abs(grid.GetCorrespondenceCost(
                                        Eigen::Array2i(x + width, y))));
    }
    // Ramp-down: the window runs past the right edge, so cells are only
    // removed until the window is empty.
    for (int x = std::max(limits.num_x_cells - width, 0);
         x != limits.num_x_cells; ++x) {
      intermediate[x + width - 1 + y * stride] = current_values.GetMaximum();
      current_values.RemoveValue(
          1.f - std::abs(grid.GetCorrespondenceCost(Eigen::Array2i(x, y))));
    }
    // Every value added for this row must have been removed again.
    current_values.CheckIsEmpty();
  }
  // For each (x, y), we compute the maximum probability in the width x width
  // region starting at each (x, y) and precompute the resulting bound on the
  // score.
  // Column-wise pass over the row maxima, with the same three phases in y; the
  // result is quantized into cells_ via ComputeCellValue.
  for (int x = 0; x != wide_limits_.num_x_cells; ++x) {
    SlidingWindowMaximum current_values;
    current_values.AddValue(intermediate[x]);
    for (int y = -width + 1; y != 0; ++y) {
      cells_[x + (y + width - 1) * stride] =
          ComputeCellValue(current_values.GetMaximum());
      if (y + width < limits.num_y_cells) {
        current_values.AddValue(intermediate[x + (y + width) * stride]);
      }
    }
    for (int y = 0; y < limits.num_y_cells - width; ++y) {
      cells_[x + (y + width - 1) * stride] =
          ComputeCellValue(current_values.GetMaximum());
      current_values.RemoveValue(intermediate[x + y * stride]);
      current_values.AddValue(intermediate[x + (y + width) * stride]);
    }
    for (int y = std::max(limits.num_y_cells - width, 0);
         y != limits.num_y_cells; ++y) {
      cells_[x + (y + width - 1) * stride] =
          ComputeCellValue(current_values.GetMaximum());
      current_values.RemoveValue(intermediate[x + y * stride]);
    }
    current_values.CheckIsEmpty();
  }
}

// Quantizes `probability` linearly from [min_score_, max_score_] onto the
// integer range [0, 255]. A value that lands outside that range aborts via
// CHECK.
uint8 PrecomputationGrid2D::ComputeCellValue(const float probability) const {
  const auto scale = 255.f / (max_score_ - min_score_);
  const auto shifted = probability - min_score_;
  const int cell_value = common::RoundToInt(shifted * scale);
  CHECK_GE(cell_value, 0);
  CHECK_LE(cell_value, 255);
  return cell_value;
}

// Builds one precomputation grid per branch-and-bound level. Level i uses a
// window width of 2^i, for i = 0 .. branch_and_bound_depth - 1 (the depth is
// CHECKed to be >= 1).
PrecomputationGridStack2D::PrecomputationGridStack2D(
    const Grid2D& grid,
    const proto::FastCorrelativeScanMatcherOptions2D& options) {
  const int depth = options.branch_and_bound_depth();
  CHECK_GE(depth, 1);
  const int widest_window = 1 << (depth - 1);
  precomputation_grids_.reserve(depth);

  // A single scratch buffer, reserved for the widest level, is shared by all
  // levels so it is only allocated once.
  const CellLimits limits = grid.limits().cell_limits();
  std::vector<float> reusable_intermediate_grid;
  reusable_intermediate_grid.reserve((limits.num_x_cells + widest_window - 1) *
                                     limits.num_y_cells);

  for (int level = 0; level < depth; ++level) {
    precomputation_grids_.emplace_back(grid, limits, 1 << level,
                                       &reusable_intermediate_grid);
  }
}

// Stores the options and grid limits and precomputes the branch-and-bound grid
// stack for `grid`. `max_matches_` is score_steps + 1: the initial attempt plus
// at most score_steps lowered-threshold retries in MatchWithDynamicParameters.
FastCorrelativeScanMatcher2D::FastCorrelativeScanMatcher2D(
    const Grid2D& grid,
    const proto::FastCorrelativeScanMatcherOptions2D& options)
    : options_(options),
      limits_(grid.limits()),
      precomputation_grid_stack_(
          absl::make_unique<PrecomputationGridStack2D>(grid, options)),
      max_matches_(options.score_steps() + 1) {
  branch_and_bound_start_time_ms_ = 1e-3 * common::GetCurrentTimeUs();
  // The adaptive search windows start at their configured minimum.
  fast_csm_dynamic_linear_search_window_ = options_.min_linear_search_window();
  fast_csm_dynamic_angular_search_window_ =
      options_.min_angular_search_window();
}

// Defined out of line so the owned PrecomputationGridStack2D is a complete
// type where the unique_ptr is destroyed.
FastCorrelativeScanMatcher2D::~FastCorrelativeScanMatcher2D() = default;

// Scan-matches `point_cloud` against the grid and dispatches on `match_type`.
//  - GLOBAL_MAP_MATCH with dynamic parameters enabled: adaptive search via
//    MatchWithDynamicParameters. That path chooses its own score thresholds,
//    so `min_score` is not used there.
//  - NORMAL_MATCH / GLOBAL_MAP_MATCH: configured windows around
//    `initial_pose_estimate`.
//  - FULL_SUBMAP_MATCH: whole submap at all orientations, centered on the
//    submap instead of the initial estimate.
// Unknown match types are logged and reported as a failed match.
// Returns true on success, in which case *pose_estimate holds the result.
bool FastCorrelativeScanMatcher2D::Match(
    int match_type, const transform::Rigid2d& initial_pose_estimate,
    const sensor::PointCloud& point_cloud, float min_score, float* score,
    transform::Rigid2d* pose_estimate) const {
  const double t0_ms = 1e-3 * common::GetCurrentTimeUs();

  if (match_type == constraints::GLOBAL_MAP_MATCH &&
      options_.use_dynamic_parameters_after_n_nodes() >= 1 &&
      use_dynamic_parameters_) {
    const bool match_success = MatchWithDynamicParameters(
        initial_pose_estimate, point_cloud, score, pose_estimate,
        &fast_csm_dynamic_linear_search_window_,
        &fast_csm_dynamic_angular_search_window_);

    const double t1_ms = 1e-3 * common::GetCurrentTimeUs();
    if (t1_ms - t0_ms > 100) {
      LOG(WARNING) << "global match cost: " << t1_ms - t0_ms << " ms";
    }
    return match_success;
  }

  transform::Rigid2d initial_pose = initial_pose_estimate;
  SearchParameters search_parameters;
  if (match_type == constraints::NORMAL_MATCH ||
      match_type == constraints::GLOBAL_MAP_MATCH) {
    search_parameters = SearchParameters(options_.linear_search_window(),
                                         options_.angular_search_window(),
                                         point_cloud, limits_.resolution());
  } else if (match_type == constraints::FULL_SUBMAP_MATCH) {
    // Effectively unbounded linear window and a full turn of rotation,
    // centered on the submap.
    search_parameters = SearchParameters(1e6 * limits_.resolution(), M_PI,
                                         point_cloud, limits_.resolution());
    initial_pose = transform::Rigid2d::Translation(
        limits_.max() -
        0.5 * limits_.resolution() *
            Eigen::Vector2d(limits_.cell_limits().num_y_cells,
                            limits_.cell_limits().num_x_cells));
  } else {
    // Without this branch an unknown type would search with a
    // default-constructed SearchParameters.
    LOG(ERROR) << "Unknown match type " << match_type
               << "; skipping scan match.";
    return false;
  }

  const bool match_success =
      MatchWithSearchParameters(search_parameters, initial_pose, point_cloud,
                                min_score, score, pose_estimate);

  const double t1_ms = 1e-3 * common::GetCurrentTimeUs();
  if (t1_ms - t0_ms > 300) {
    LOG(WARNING) << "global match cost: " << t1_ms - t0_ms << " ms";
  }
  return match_success;
}

// Adaptive scan match used for global matches.
//
// 1. Try the current windows with the strictest threshold (options max_score).
// 2. On failure, widen the linear/angular window by one step (capped at the
//    configured maximum) if the motion accumulated since the last success
//    exceeds that window.
// 3. Retry with the threshold lowered by (max_score - min_score) / score_steps
//    each time, for at most max_matches_ - 1 retries.
// On success, the motion accumulators are cleared and both windows return to
// their configured minimum. The windows are in/out parameters that persist
// across calls.
bool FastCorrelativeScanMatcher2D::MatchWithDynamicParameters(
    const transform::Rigid2d& initial_pose_estimate,
    const sensor::PointCloud& point_cloud, float* score,
    transform::Rigid2d* pose_estimate, double* dynamic_linear_search_window,
    double* dynamic_angular_search_window) const {
  // Accumulate the motion since the previous call.
  const transform::Rigid2d delta_pose =
      last_scan_match_pose_.inverse() * initial_pose_estimate;
  traveled_after_last_constarint_ += delta_pose.translation().norm();
  traveled_angle_after_last_constarint_ +=
      std::fabs(delta_pose.rotation().angle());
  last_scan_match_pose_ = initial_pose_estimate;

  const auto reset_after_success = [&]() {
    traveled_after_last_constarint_ = 0.0;
    traveled_angle_after_last_constarint_ = 0.0;
    *dynamic_linear_search_window = options_.min_linear_search_window();
    *dynamic_angular_search_window = options_.min_angular_search_window();
  };

  bool match_success = MatchWithSearchParameters(
      SearchParameters(*dynamic_linear_search_window,
                       *dynamic_angular_search_window, point_cloud,
                       limits_.resolution()),
      initial_pose_estimate, point_cloud, options_.max_score(), score,
      pose_estimate);
  if (match_success) {
    reset_after_success();
    return true;
  }

  // Step counts below 1 are treated as 1. A step count of 0 would otherwise
  // make the increment inf/NaN and corrupt the persistent window.
  if (traveled_after_last_constarint_ > *dynamic_linear_search_window) {
    const double linear_increment =
        (options_.max_linear_search_window() -
         options_.min_linear_search_window()) /
        std::max<int>(1, options_.linear_search_window_steps());
    *dynamic_linear_search_window =
        std::min<double>(*dynamic_linear_search_window + linear_increment,
                         options_.max_linear_search_window());
  }
  if (traveled_angle_after_last_constarint_ > *dynamic_angular_search_window) {
    const double angular_increment =
        (options_.max_angular_search_window() -
         options_.min_angular_search_window()) /
        std::max<int>(1, options_.angular_search_window_steps());
    *dynamic_angular_search_window =
        std::min<double>(*dynamic_angular_search_window + angular_increment,
                         options_.max_angular_search_window());
  }

  const double score_decrement =
      (options_.max_score() - options_.min_score()) / options_.score_steps();
  double dynamic_min_score = options_.max_score();
  for (int match_times = 1; match_times < max_matches_ && !match_success;
       ++match_times) {
    dynamic_min_score -= score_decrement;
    match_success = MatchWithSearchParameters(
        SearchParameters(*dynamic_linear_search_window,
                         *dynamic_angular_search_window, point_cloud,
                         limits_.resolution()),
        initial_pose_estimate, point_cloud, dynamic_min_score, score,
        pose_estimate);
  }

  if (match_success) {
    // Log before resetting so the message reports the windows that matched,
    // not the minimums they are reset to.
    LOG_EVERY_N(WARNING, 20)
        << "Match success: score_search_window = [" << *score << ", "
        << *dynamic_linear_search_window << ", "
        << *dynamic_angular_search_window << "]";
    reset_after_success();
  } else {
    LOG(WARNING) << "Match failed: score_search_window = [" << *score << ", "
                 << *dynamic_linear_search_window << ", "
                 << *dynamic_angular_search_window << "]";
  }
  return match_success;
}

// Runs one branch-and-bound scan match with the given search parameters.
// *score always receives the best candidate's score. *pose_estimate is written
// and true returned only when that score exceeds `min_score`.
bool FastCorrelativeScanMatcher2D::MatchWithSearchParameters(
    SearchParameters search_parameters,
    const transform::Rigid2d& initial_pose_estimate,
    const sensor::PointCloud& point_cloud, float min_score, float* score,
    transform::Rigid2d* pose_estimate) const {
  CHECK(score != nullptr);
  CHECK(pose_estimate != nullptr);

  // Pre-rotate the cloud by the initial heading so the angular sweep is
  // centered on it; the initial translation is applied while discretizing.
  const Eigen::Rotation2Dd initial_rotation = initial_pose_estimate.rotation();
  const Eigen::Vector2d initial_translation =
      initial_pose_estimate.translation();
  const transform::Rigid3f heading_rotation =
      transform::Rigid3f::Rotation(Eigen::AngleAxisf(
          initial_rotation.cast<float>().angle(), Eigen::Vector3f::UnitZ()));
  const sensor::PointCloud rotated_point_cloud =
      sensor::TransformPointCloud(point_cloud, heading_rotation);

  const std::vector<DiscreteScan2D> discrete_scans = DiscretizeScans(
      limits_, GenerateRotatedScans(rotated_point_cloud, search_parameters),
      Eigen::Translation2f(initial_translation.x(), initial_translation.y()));
  search_parameters.ShrinkToFit(discrete_scans, limits_.cell_limits());

  const std::vector<Candidate2D> lowest_resolution_candidates =
      ComputeLowestResolutionCandidates(discrete_scans, search_parameters);

  // The branch-and-bound time budget is measured from here on.
  branch_and_bound_start_time_ms_ = 1e-3 * common::GetCurrentTimeUs();
  const Candidate2D best_candidate = BranchAndBound(
      discrete_scans, search_parameters, lowest_resolution_candidates,
      precomputation_grid_stack_->max_depth(), min_score);

  const bool budget_exhausted = branch_and_bound_time_cost_ms_ >
                                    branch_and_bound_time_cost_threshold_ms_ &&
                                is_trajectory_init_;
  if (budget_exhausted) {
    LOG(WARNING) << "BranchAndBound  compute cost "
                 << branch_and_bound_time_cost_ms_
                 << "ms,  calculation will be terminated!";
  }

  *score = best_candidate.score;
  if (!(best_candidate.score > min_score)) {
    return false;
  }
  *pose_estimate = transform::Rigid2d(
      {initial_translation.x() + best_candidate.x,
       initial_translation.y() + best_candidate.y},
      initial_rotation * Eigen::Rotation2Dd(best_candidate.orientation));
  return true;
}

std::vector<Candidate2D>
FastCorrelativeScanMatcher2D::ComputeLowestResolutionCandidates(
    const std::vector<DiscreteScan2D>& discrete_scans,
    const SearchParameters& search_parameters) const {
  std::vector<Candidate2D> lowest_resolution_candidates =
      GenerateLowestResolutionCandidates(search_parameters);
  ScoreCandidates(
      precomputation_grid_stack_->Get(precomputation_grid_stack_->max_depth()),
      discrete_scans, search_parameters, &lowest_resolution_candidates);
  return lowest_resolution_candidates;
}

// Enumerates, for every rotated scan, the (x, y) index offsets within its
// linear bounds on a grid of step 2^max_depth. These are the unscored root
// candidates of the branch-and-bound search.
std::vector<Candidate2D>
FastCorrelativeScanMatcher2D::GenerateLowestResolutionCandidates(
    const SearchParameters& search_parameters) const {
  const int linear_step_size = 1 << precomputation_grid_stack_->max_depth();

  // Number of offsets min, min + step, ... that are <= max. Clamped at zero so
  // an inverted range (max < min) yields no candidates, matching the loops
  // below. Unclamped, such a range could give a negative count, or a positive
  // product of two negative counts, and trip the CHECK.
  const auto num_offsets = [linear_step_size](const int min_offset,
                                              const int max_offset) {
    return std::max(
        0, (max_offset - min_offset + linear_step_size) / linear_step_size);
  };

  int num_candidates = 0;
  for (int scan_index = 0; scan_index != search_parameters.num_scans;
       ++scan_index) {
    const auto& bounds = search_parameters.linear_bounds[scan_index];
    num_candidates += num_offsets(bounds.min_x, bounds.max_x) *
                      num_offsets(bounds.min_y, bounds.max_y);
  }

  std::vector<Candidate2D> candidates;
  candidates.reserve(num_candidates);
  for (int scan_index = 0; scan_index != search_parameters.num_scans;
       ++scan_index) {
    const auto& bounds = search_parameters.linear_bounds[scan_index];
    for (int x_index_offset = bounds.min_x; x_index_offset <= bounds.max_x;
         x_index_offset += linear_step_size) {
      for (int y_index_offset = bounds.min_y; y_index_offset <= bounds.max_y;
           y_index_offset += linear_step_size) {
        candidates.emplace_back(scan_index, x_index_offset, y_index_offset,
                                search_parameters);
      }
    }
  }
  CHECK_EQ(candidates.size(), static_cast<std::size_t>(num_candidates));
  return candidates;
}

// Scores each candidate and sorts the list in descending score order. The
// score is the precomputation-grid value averaged over the candidate's
// discrete scan shifted by its index offset, mapped back through ToScore.
// A candidate whose scan has no points gets score 0. Without that guard, the
// 0/0 average would be NaN, and NaN scores never satisfy the
// `score <= min_score` pruning test in BranchAndBound.
void FastCorrelativeScanMatcher2D::ScoreCandidates(
    const PrecomputationGrid2D& precomputation_grid,
    const std::vector<DiscreteScan2D>& discrete_scans,
    const SearchParameters& search_parameters,
    std::vector<Candidate2D>* const candidates) const {
  for (Candidate2D& candidate : *candidates) {
    const DiscreteScan2D& discrete_scan = discrete_scans[candidate.scan_index];
    if (discrete_scan.empty()) {
      candidate.score = 0.f;
      continue;
    }
    int sum = 0;
    for (const Eigen::Array2i& xy_index : discrete_scan) {
      const Eigen::Array2i proposed_xy_index(
          xy_index.x() + candidate.x_index_offset,
          xy_index.y() + candidate.y_index_offset);
      sum += precomputation_grid.GetValue(proposed_xy_index);
    }
    candidate.score = precomputation_grid.ToScore(
        sum / static_cast<float>(discrete_scan.size()));
  }
  std::sort(candidates->begin(), candidates->end(),
            std::greater<Candidate2D>());
}

// Depth-first branch and bound over `candidates`, which are sorted best-first.
//
// At depth 0 the first candidate is the answer. Otherwise each candidate whose
// upper-bound score exceeds `min_score` is split into up to four children
// (offsets 0 and 2^(depth-1) in x and y, clipped to the linear bounds). The
// children are scored on the next finer grid and searched recursively, using
// the best score found so far as the new threshold.
//
// Once the time budget (branch_and_bound_time_cost_threshold_ms_, enforced
// only when is_trajectory_init_) is exceeded, no further candidates are
// expanded and the best result found so far is returned.
Candidate2D FastCorrelativeScanMatcher2D::BranchAndBound(
    const std::vector<DiscreteScan2D>& discrete_scans,
    const SearchParameters& search_parameters,
    const std::vector<Candidate2D>& candidates, const int candidate_depth,
    float min_score) const {
  if (candidate_depth == 0) {
    if (candidates.empty()) {
      // Nothing to pick from: return a placeholder that cannot beat min_score.
      Candidate2D no_candidate(0, 0, 0, search_parameters);
      no_candidate.score = min_score;
      return no_candidate;
    }
    // Return the best candidate (the list is sorted by descending score).
    return candidates.front();
  }

  // Refreshes the elapsed branch-and-bound time and reports whether the
  // budget is exhausted.
  const auto time_budget_exceeded = [this]() {
    branch_and_bound_time_cost_ms_ =
        1e-3 * common::GetCurrentTimeUs() - branch_and_bound_start_time_ms_;
    return branch_and_bound_time_cost_ms_ >
               branch_and_bound_time_cost_threshold_ms_ &&
           is_trajectory_init_;
  };

  Candidate2D best_high_resolution_candidate(0, 0, 0, search_parameters);
  if (time_budget_exceeded()) {
    return best_high_resolution_candidate;
  }

  best_high_resolution_candidate.score = min_score;
  const int half_width = 1 << (candidate_depth - 1);
  for (const Candidate2D& candidate : candidates) {
    if (candidate.score <= min_score) {
      break;
    }
    // Also check inside the loop. Otherwise, after a timeout, every remaining
    // candidate would still have its children generated and scored here, even
    // though the recursive calls return immediately.
    if (time_budget_exceeded()) {
      break;
    }
    std::vector<Candidate2D> higher_resolution_candidates;
    for (int x_offset : {0, half_width}) {
      if (candidate.x_index_offset + x_offset >
          search_parameters.linear_bounds[candidate.scan_index].max_x) {
        break;
      }
      for (int y_offset : {0, half_width}) {
        if (candidate.y_index_offset + y_offset >
            search_parameters.linear_bounds[candidate.scan_index].max_y) {
          break;
        }
        higher_resolution_candidates.emplace_back(
            candidate.scan_index, candidate.x_index_offset + x_offset,
            candidate.y_index_offset + y_offset, search_parameters);
      }
    }
    ScoreCandidates(precomputation_grid_stack_->Get(candidate_depth - 1),
                    discrete_scans, search_parameters,
                    &higher_resolution_candidates);
    best_high_resolution_candidate = std::max(
        best_high_resolution_candidate,
        BranchAndBound(discrete_scans, search_parameters,
                       higher_resolution_candidates, candidate_depth - 1,
                       best_high_resolution_candidate.score));
  }
  return best_high_resolution_candidate;
}

}  // namespace scan_matching
}  // namespace mapping
}  // namespace slam
